#!/usr/bin/env python3

from PIL import Image
import sys

# Command-line handling: exactly two arguments are expected, the image to be
# checked (<output>) followed by the image it is checked against (<reference>).
if len(sys.argv) != 3:
	sys.exit(f"Usage: {sys.argv[0]} <output> <reference>")

model = Image.open(sys.argv[1])
gold = Image.open(sys.argv[2])

# Image.getpixel() returns a bare number (for "P" images, a palette index)
# instead of a tuple for single-band images, which the per-pixel comparison
# further down cannot index. Normalise everything that is not already
# RGB/RGBA to RGB so both images always yield comparable colour tuples.
# RGB and RGBA images are left untouched; the comparison only looks at the
# first three channels anyway.
if model.mode not in ("RGB", "RGBA"):
	model = model.convert("RGB")
if gold.mode not in ("RGB", "RGBA"):
	gold = gold.convert("RGB")

# The output may be larger than the reference, but never smaller.
if model.width < gold.width or model.height < gold.height:
	sys.exit("Model output is smaller than reference image!")

# After the guard above the reference is the smaller (or equal) image in both
# dimensions, so its size is the area that gets compared.
compwidth, compheight = gold.width, gold.height

# Starts out as Image.new()'s default fill; mismatching pixels are painted
# into it by the comparison loop.
invalid_mask = Image.new("RGB", (compwidth, compheight))
invalid_count = 0

def pixels_equal(pix0, pix1):
	"""Return True if two pixels match in the top five bits of each channel.

	Each channel value is masked with 0xf8, so differences confined to the
	low three bits are ignored. Only the first three channels are examined;
	a fourth channel (e.g. alpha) never takes part in the comparison.

	Either argument may be a sequence of integer channel values or a single
	integer (a one-channel pixel). When the two pixels have a different
	number of channels, only the channels present in both are compared.
	"""
	# Single-band pixels arrive as a bare int; treat them as 1-tuples so the
	# channel-wise comparison below works uniformly.
	if isinstance(pix0, int):
		pix0 = (pix0,)
	if isinstance(pix1, int):
		pix1 = (pix1,)
	mask = 0xf8
	# zip() stops at the shorter pixel, so a channel-count mismatch can no
	# longer index past the end of either one.
	return all((a & mask) == (b & mask) for a, b in zip(pix0[:3], pix1[:3]))

# Compare every pixel of the overlapping area, column by column (x is the
# outer loop, so the "first" invalid pixel is the first one in that order).
for x in range(compwidth):
	for y in range(compheight):
		if pixels_equal(model.getpixel((x, y)), gold.getpixel((x, y))):
			continue
		if invalid_count == 0:
			print(f"First invalid pixel encountered at ({x}, {y})")
		invalid_count += 1
		# Mark the mismatch in red in the diagnostic mask.
		invalid_mask.putpixel((x, y), (255, 0, 0))

# The mask is only written when there is something to show.
if invalid_count > 0:
	invalid_mask.save("invalid_mask.png")
	print("Invalid pixels written to file invalid_mask.png")

# The exit status is the number of invalid pixels, clamped to 255. POSIX keeps
# only the low 8 bits of the status, so an unclamped count that is a multiple
# of 256 would be reported to the caller as 0, i.e. success.
sys.exit(min(invalid_count, 255))
